
Enable CPU/XPU native and ipex path #1628


Merged: 40 commits, May 28, 2025

Conversation

jiqing-feng
Contributor

@jiqing-feng jiqing-feng commented May 8, 2025

This PR enables IPEX and other optimizations, including:

  1. the IPEX fused op
  2. enabling FP4 on CPU
  3. enabling has_rem on 4-bit quantize/dequantize
  4. a simple 8-bit matmul to make fine-tuning faster on CPU

It also fixes the parameter patching for CPU.

It passes all transformers tests.

After this PR is merged, I will update the installation guide.
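For readers unfamiliar with the "simple 8-bit matmul" idea, a minimal dequantize-then-matmul sketch in pure PyTorch is shown below. The function name, tensor layout, and per-row absmax scaling are illustrative assumptions, not the actual bitsandbytes implementation:

```python
import torch

# Hypothetical sketch: dequantize an int8 weight with its per-row scales,
# then fall back to a regular fp32 matmul. This trades a small dequantize
# cost for using the fast native CPU matmul kernel.
def simple_int8_matmul(x: torch.Tensor, w_int8: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # w_int8: (out_features, in_features) int8; scales: (out_features,) fp32
    w = w_int8.to(torch.float32) * scales.unsqueeze(1)  # per-row dequantize
    return x @ w.t()

# Round-trip example with per-row absmax quantization.
x = torch.randn(2, 4)
w = torch.randn(3, 4)
scales = w.abs().amax(dim=1) / 127.0
w_int8 = torch.clamp((w / scales.unsqueeze(1)).round(), -127, 127).to(torch.int8)
out = simple_int8_matmul(x, w_int8, scales)
print(out.shape)  # torch.Size([2, 3])
```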

@matthewdouglas @Titus-von-Koeller

Signed-off-by: jiqing-feng <[email protected]>
Signed-off-by: jiqing-feng <[email protected]>
@jiqing-feng jiqing-feng marked this pull request as ready for review May 8, 2025 07:25
Signed-off-by: jiqing-feng <[email protected]>

github-actions bot commented May 8, 2025

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@jiqing-feng
Contributor Author

I am cleaning up the CPU and XPU tests; progress is about 50%.

quant_state.blocksize,
quant_state.shape,
quant_state.dtype,
)
Contributor


Is there a reason why this change can't be in bitsandbytes/backends/cpu/ops.py?


Signed-off-by: jiqing-feng <[email protected]>
Signed-off-by: jiqing-feng <[email protected]>
Signed-off-by: jiqing-feng <[email protected]>
@matthewdouglas matthewdouglas added this to the v0.47.0 milestone May 9, 2025
Signed-off-by: jiqing-feng <[email protected]>
@jiqing-feng
Contributor Author

Hi @matthewdouglas. Please check whether anything was missed before merging. We can discuss the tests in issue #1637.

@jiqing-feng jiqing-feng mentioned this pull request May 13, 2025
Comment on lines 299 to 303
if not ipex_cpu:
    logger.warning(
        "The installed version of bitsandbytes was compiled without IPEX support. "
        "You can install IPEX by running `pip install intel_extension_for_pytorch` "
        "to get better performance if you are using an Intel CPU.",
    )
Member


This seems like extra noise that we'd want to avoid.

Something to point out is that we still plan to ship libbitsandbytes_cpu in our wheels, so for most users, it's going to load a CPU, CUDA, or eventually ROCm or Metal library and we'll hit this logging line. At most we should really only raise this warning when:

  1. We're on a platform with IPEX CPU support. My understanding is this is limited to Linux x86-64.
  2. We expect the user to be using CPU, i.e. no CUDA, XPU, or MPS accelerators available.
    On torch >= 2.6 we could just use torch.accelerator.is_available() and on older versions I think we can overlook privateuse1 backends like HPU or Ascend NPU.
  3. There's some expectation of IPEX being beneficial. We don't want to prompt users to install it if e.g. it needs AVX512 or AMX support to be effective. This is something I can't speak to directly but defer to Intel folks to determine.

Any other thoughts @Titus-von-Koeller ?
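The three conditions above could be combined into a single guard before logging. This is only a hedged sketch of that predicate under the assumptions stated in the comment (IPEX CPU limited to Linux x86-64, `torch.accelerator` on torch >= 2.6); the exact check bitsandbytes ends up using may differ:

```python
import platform
import torch

def should_warn_about_ipex() -> bool:
    # 1. Assume IPEX CPU support is limited to Linux x86-64.
    if platform.system() != "Linux" or platform.machine() not in ("x86_64", "AMD64"):
        return False
    # 2. Only warn when the user is likely running on CPU (no accelerator).
    if hasattr(torch, "accelerator"):  # torch >= 2.6
        if torch.accelerator.is_available():
            return False
    elif torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available()):
        return False
    # 3. A CPU-capability check (e.g. AVX512/AMX) would go here; omitted,
    #    since whether IPEX is beneficial depends on the hardware.
    return True
```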

Contributor Author


You're right. I agree that the warning should only be emitted if no accelerators like CUDA/XPU are available and the CPU is an Intel product.

Contributor Author


I have changed it; please review again. Thanks!

@jiqing-feng
Contributor Author

jiqing-feng commented May 16, 2025

Hi @matthewdouglas. Some of the CPU ops are implemented in pure Python, which works on all devices. Could we move these pure-Python implementations to the default folder, since some of these ops could be reused by XPU?
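To illustrate why pure-PyTorch ops are device-agnostic, here is a hypothetical blockwise absmax int8 quantization sketch (the function name and layout are assumptions, not the bitsandbytes API). Because it only uses tensor ops, the same code runs unchanged on CPU, XPU, or CUDA tensors, which is what makes sharing it from a default folder attractive:

```python
import torch

def quantize_blockwise_int8(a: torch.Tensor, blocksize: int = 64):
    # Flatten and pad so the length is a multiple of the blocksize
    # (a simple way to handle the remainder, cf. has_rem).
    flat = a.reshape(-1)
    pad = (-flat.numel()) % blocksize
    if pad:
        flat = torch.nn.functional.pad(flat, (0, pad))
    blocks = flat.reshape(-1, blocksize)
    # One scale per block: the block's max absolute value.
    absmax = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    q = torch.clamp((blocks / absmax * 127).round(), -127, 127).to(torch.int8)
    return q, absmax.squeeze(1)

a = torch.randn(130)
q, absmax = quantize_blockwise_int8(a, blocksize=64)
print(q.shape, absmax.shape)  # torch.Size([3, 64]) torch.Size([3])
```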

@jiqing-feng jiqing-feng changed the title Enable ipex and other optimizations Enable CPU/XPU native and ipex path May 21, 2025
@jiqing-feng
Contributor Author

jiqing-feng commented May 21, 2025

Hi @matthewdouglas. I moved the pure-PyTorch ops to the default folder and it works well. All CI checks pass. Please continue reviewing this PR.

Signed-off-by: jiqing-feng <[email protected]>
Signed-off-by: jiqing-feng <[email protected]>
Signed-off-by: jiqing-feng <[email protected]>
Signed-off-by: jiqing-feng <[email protected]>
Signed-off-by: jiqing-feng <[email protected]>
@matthewdouglas matthewdouglas merged commit aaa71d7 into bitsandbytes-foundation:main May 28, 2025
41 checks passed